import torch
import math
import torch.nn as nn
from config.global_config import GlobalConfig
from config.model_config import TextEncoderConfig
from util import length_to_mask


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    For position p and column index i of a d-dimensional embedding:
    PE(p, 2k) = sin(p / 10000^(2k/d)) and PE(p, 2k+1) = cos(p / 10000^(2k/d)).

    Args:
        embed_size: Embedding dimension d. Odd values are supported; the last
            (even-indexed) column then has no cosine partner.
        dropout: Dropout probability applied to ``x + PE`` in ``forward``.
        max_len: Longest sequence the precomputed table supports.
    """

    def __init__(self, embed_size, dropout=0.1, max_len=512):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, embed_size)  # (max_len, d)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(2k/d) for every even column 2k, computed in log space:
        # exp(2k * (-ln(10000) / d)).
        div_term = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_size there is one fewer odd column than even
        # column, so trim div_term to match the 1::2 slice.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])
        pe = pe.unsqueeze(1)  # (max_len, 1, d); broadcasts over the batch axis
        # A buffer is not a trainable parameter, but it follows .to()/.cuda()
        # and is stored in state_dict.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to ``x`` and apply dropout.

        Args:
            x: (seq_len, batch_size, embed_size) tensor.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding "
                f"max_len {self.pe.size(0)}")
        x = x + self.pe[:seq_len]  # pe slice: (seq_len, 1, d)
        return self.dropout(x)


class TextEncoder(nn.Module):
    """Transformer text encoder.

    Pipeline: token embedding -> sinusoidal positional encoding ->
    ``nn.TransformerEncoder`` -> linear projection from ``embed_size`` to
    ``d_model``. Each submodule is moved to ``GlobalConfig.device`` when built.

    Args:
        config: Supplies ``vocab_size``, ``embed_size``, ``pad_index``,
            ``embed_init`` (optional pretrained embedding matrix, or ``None``),
            ``num_head``, ``dim_feedforward``, ``t_dropout``, ``activation``,
            ``num_layers``, ``max_len`` and ``d_model``.

    Raises:
        ValueError: If ``embed_init`` is given and its second dimension does
            not equal ``embed_size``.
    """

    def __init__(self, config: TextEncoderConfig):
        super(TextEncoder, self).__init__()

        # Embedding. Built unconditionally so the random-number stream consumed
        # by the layers created below does not depend on whether pretrained
        # weights are supplied.
        self.embedding = nn.Embedding(
            config.vocab_size,
            config.embed_size,
            padding_idx=config.pad_index)
        if config.embed_init is not None:
            if config.embed_init.shape[1] != config.embed_size:
                raise ValueError(
                    f"embed_init has embedding dim {config.embed_init.shape[1]}, "
                    f"but config.embed_size is {config.embed_size}")
            # from_pretrained is a classmethod that returns a brand-new module,
            # so padding_idx must be passed again or it is silently dropped
            # (and the pad row would then receive gradient updates).
            self.embedding = nn.Embedding.from_pretrained(
                config.embed_init, freeze=False, padding_idx=config.pad_index)
        # Done after the optional replacement so pretrained weights also land
        # on the configured device.
        self.embedding = self.embedding.to(GlobalConfig.device)

        # Transformer stack. nn.TransformerEncoder deep-copies transformer_layer
        # num_layers times; the template attribute is kept so existing
        # state_dict keys remain valid, but forward does not call it directly.
        self.transformer_layer = nn.TransformerEncoderLayer(
            config.embed_size, config.num_head, config.dim_feedforward,
            config.t_dropout, config.activation).to(GlobalConfig.device)
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_layer, config.num_layers).to(GlobalConfig.device)
        self.position_embedding = PositionalEncoding(
            config.embed_size, config.t_dropout, config.max_len).to(GlobalConfig.device)
        self.linear = nn.Linear(config.embed_size, config.d_model).to(GlobalConfig.device)

    def forward(self, input_seq):
        """Encode a batch of token-id sequences.

        Args:
            input_seq: (batch_size, seq_len) token ids.

        Returns:
            Tuple ``(output, hiddens)``:
            output: (batch_size, d_model), the projected hidden state at the
                first sequence position.
            hiddens: (seq_len, batch_size, d_model), projected hidden states
                for every position (sequence-first layout).
        """
        # The encoder layer is built without batch_first, so it expects
        # sequence-first input: (seq_len, batch_size, ...).
        input_seq = input_seq.transpose(0, 1)  # (seq_len, batch_size)
        embedded = self.embedding(input_seq)  # (seq_len, batch_size, embed_size)
        embedded = self.position_embedding(embedded)
        # NOTE(review): no src_key_padding_mask is passed, so positions holding
        # pad_index are attended to like any other token — confirm intended.
        hiddens = self.transformer_encoder(embedded)  # (seq_len, batch_size, embed_size)
        hiddens = self.linear(hiddens)  # (seq_len, batch_size, d_model)
        output = hiddens[0]  # (batch_size, d_model)
        return output, hiddens
